{
  "nbformat": 4,
  "nbformat_minor": 5,
  "metadata": {
    "jupytext": {
      "cell_metadata_filter": "-all",
      "main_language": "python",
      "notebook_metadata_filter": "-all"
    },
    "colab": {
      "name": "generated-notebook.ipynb",
      "provenance": []
    },
    "language_info": {
      "name": "python"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a2eddf1e"
      },
      "source": [
        "Before running, install required packages:"
      ],
      "id": "a2eddf1e"
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fcdc34bc",
        "outputId": "b841c73e-4355-4b2b-cc7d-648db45f8401"
      },
      "source": [
        "# %pip installs into the environment of the running kernel.\n",
        "%pip install numpy torch torchvision pytorch-ignite"
      ],
      "id": "fcdc34bc",
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (1.19.5)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.9.0+cu102)\n",
            "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (0.10.0+cu102)\n",
            "Collecting pytorch-ignite\n",
            "  Downloading pytorch_ignite-0.4.6-py3-none-any.whl (232 kB)\n",
            "\u001b[K     |████████████████████████████████| 232 kB 40.4 MB/s \n",
            "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.7.4.3)\n",
            "Requirement already satisfied: pillow>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision) (7.1.2)\n",
            "Installing collected packages: pytorch-ignite\n",
            "Successfully installed pytorch-ignite-0.4.6\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "4993df1a"
      },
      "source": [
        "---"
      ],
      "id": "4993df1a"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c715bf41"
      },
      "source": [
        "import urllib\n",
        "import urllib.request\n",
        "import zipfile\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
        "from ignite.metrics import Accuracy, Loss\n",
        "from torch import optim, nn\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "from torchvision import models, datasets, transforms"
      ],
      "id": "c715bf41",
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "f651183a"
      },
      "source": [
        "# COMMENT THIS OUT IF YOU USE YOUR OWN DATA.\n",
        "# Download example data into ./data/image-data (4 image files, 2 for \"dog\", 2 for \"cat\").\n",
        "url = \"https://github.com/jrieke/traingenerator/raw/main/data/fake-image-data.zip\"\n",
        "try:\n",
        "    # Without a target filename, urlretrieve stores the download in a temporary file.\n",
        "    zip_path, _ = urllib.request.urlretrieve(url)\n",
        "    with zipfile.ZipFile(zip_path, \"r\") as f:\n",
        "        f.extractall(\"data\")\n",
        "finally:\n",
        "    # Delete the temporary download so repeated runs do not accumulate files.\n",
        "    urllib.request.urlcleanup()"
      ],
      "id": "f651183a",
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "362f2df7"
      },
      "source": [
        "# Setup"
      ],
      "id": "362f2df7"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b2c49478"
      },
      "source": [
        "# INSERT YOUR DATA HERE\n",
        "# Each variable is the path of an image folder; val_data and test_data may be None.\n",
        "# Expected format: one sub-folder per class, e.g.\n",
        "# train\n",
        "# +-- dogs\n",
        "# |   +-- lassie.jpg\n",
        "# |   +-- komissar-rex.png\n",
        "# +-- cats\n",
        "#     +-- garfield.png\n",
        "#     +-- smelly-cat.png\n",
        "#\n",
        "# Example: https://github.com/jrieke/traingenerator/tree/main/data/image-data\n",
        "# NOTE: val_data points at the same folder as train_data here, so the reported\n",
        "# validation metrics are computed on the training images.\n",
        "train_data = \"data/image-data\"  # required\n",
        "val_data = \"data/image-data\"    # optional\n",
        "test_data = None                # optional"
      ],
      "id": "b2c49478",
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0b99bb18"
      },
      "source": [
        "# Set up hyperparameters.\n",
        "lr = 0.001\n",
        "batch_size = 16\n",
        "num_epochs = 10"
      ],
      "id": "0b99bb18",
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dbdcab26"
      },
      "source": [
        "# Set up logging.\n",
        "# The training loss is printed once every `print_every` iterations (batches).\n",
        "print_every = 10  # batches"
      ],
      "id": "dbdcab26",
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "4de6273b"
      },
      "source": [
        "# Set up device.\n",
        "use_cuda = torch.cuda.is_available()\n",
        "device = torch.device(\"cuda\" if use_cuda else \"cpu\")"
      ],
      "id": "4de6273b",
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "6af21daa"
      },
      "source": [
        "# Preprocessing"
      ],
      "id": "6af21daa"
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 1,
        "id": "8803fce4"
      },
      "source": [
        "import torchvision.transforms.functional as F\n",
        "\n",
        "class SquarePad:\n",
        "    \"\"\"Pad an image with a constant gray border so that it becomes square.\n",
        "\n",
        "    The image stays centred. When the difference between width and height is\n",
        "    odd, the extra pixel goes to the right/bottom edge so the result is\n",
        "    exactly square instead of one pixel short.\n",
        "    \"\"\"\n",
        "\n",
        "    def __call__(self, image):\n",
        "        \"\"\"Return `image` padded to max(width, height) along both axes.\n",
        "\n",
        "        `image` must expose `.size` as (width, height) and be accepted by\n",
        "        torchvision's functional `pad` (e.g. a PIL image).\n",
        "        \"\"\"\n",
        "        w, h = image.size\n",
        "        side = max(w, h)\n",
        "        # Split the missing pixels over both sides; the second side takes the\n",
        "        # remainder so nothing is lost to integer rounding.\n",
        "        left = (side - w) // 2\n",
        "        top = (side - h) // 2\n",
        "        right = side - w - left\n",
        "        bottom = side - h - top\n",
        "        # Padding order is (left, top, right, bottom); 114 is the gray fill value.\n",
        "        return F.pad(image, (left, top, right, bottom), 114, 'constant')\n",
        "\n",
        "  \n",
        "def preprocess(data, name):\n",
        "    if data is None:  # val/test can be empty\n",
        "        return None\n",
        "\n",
        "    # Read image files to pytorch dataset.\n",
        "    transform = transforms.Compose([\n",
        "        SquarePad(),\n",
        "        transforms.Grayscale(),\n",
        "        transforms.Resize(128), \n",
        "        transforms.CenterCrop(128), \n",
        "        transforms.ToTensor(), \n",
        "        \n",
        "    ])\n",
        "    dataset = datasets.ImageFolder(data, transform=transform)\n",
        "\n",
        "    # Wrap in data loader.\n",
        "    if use_cuda:\n",
        "        kwargs = {\"pin_memory\": True, \"num_workers\": 1}\n",
        "    else:\n",
        "        kwargs = {}\n",
        "    loader = DataLoader(dataset, batch_size=batch_size, shuffle=(name==\"train\"), **kwargs)\n",
        "    return loader"
      ],
      "id": "8803fce4",
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "c38179a9"
      },
      "source": [
        "# Build one loader per split; preprocess returns None for a split whose path is None.\n",
        "train_loader = preprocess(train_data, \"train\")\n",
        "val_loader = preprocess(val_data, \"val\")\n",
        "test_loader = preprocess(test_data, \"test\")"
      ],
      "id": "c38179a9",
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "da07871a"
      },
      "source": [
        "# Model"
      ],
      "id": "da07871a"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v2CLBso_3nWq"
      },
      "source": [
        "# Define model\n",
        "class NeuralNetwork(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(NeuralNetwork, self).__init__()\n",
        "        self.sequential = nn.Sequential(\n",
        "            nn.Linear(16384,100),# 输入为 128x128 的 灰度图\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(100,2),\n",
        "            nn.Softmax(dim=1)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x.view(-1,16384)\n",
        "        return self.sequential(x)\n",
        "\n"
      ],
      "id": "v2CLBso_3nWq",
      "execution_count": 38,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 2,
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3cb2211e",
        "outputId": "b1bcef83-d558-4126-d9a9-b74be24afed8"
      },
      "source": [
        "# Set up model, loss, optimizer.\n",
        "model = NeuralNetwork()\n",
        "model = model.to(device)\n",
        "\n",
        "\n",
        "def loss_func(y_pred, y):\n",
        "    \"\"\"Cross-entropy for a model that outputs probabilities.\n",
        "\n",
        "    NeuralNetwork ends in nn.Softmax, so `y_pred` is already normalised.\n",
        "    nn.CrossEntropyLoss expects raw logits and would apply a second softmax,\n",
        "    which distorts the loss and weakens the gradients. Taking the negative\n",
        "    log-likelihood of the predicted probabilities gives the intended loss.\n",
        "\n",
        "    Args:\n",
        "        y_pred: Tensor of shape (N, num_classes) holding class probabilities.\n",
        "        y: Tensor of shape (N,) holding the target class indices.\n",
        "\n",
        "    Returns:\n",
        "        Scalar tensor with the mean loss over the batch.\n",
        "    \"\"\"\n",
        "    # Clamp before the log so a probability that underflows to 0 cannot give -inf.\n",
        "    log_probs = torch.log(y_pred.clamp_min(1e-12))\n",
        "    return nn.functional.nll_loss(log_probs, y)\n",
        "\n",
        "\n",
        "optimizer = optim.Adam(model.parameters(), lr=lr)\n",
        "print(model)"
      ],
      "id": "3cb2211e",
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "NeuralNetwork(\n",
            "  (sequential): Sequential(\n",
            "    (0): Linear(in_features=16384, out_features=100, bias=True)\n",
            "    (1): ReLU()\n",
            "    (2): Linear(in_features=100, out_features=2, bias=True)\n",
            "    (3): Softmax(dim=1)\n",
            "  )\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "lines_to_next_cell": 2,
        "id": "25c3efa1"
      },
      "source": [
        "# Training"
      ],
      "id": "25c3efa1"
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 1,
        "id": "2cc6d605"
      },
      "source": [
        "# Set up pytorch-ignite trainer and evaluator.\n",
        "# The trainer optimises `model` with `loss_func`; the evaluator computes the metrics below.\n",
        "trainer = create_supervised_trainer(model, optimizer, loss_func, device=device)\n",
        "\n",
        "metrics = dict(accuracy=Accuracy(), loss=Loss(loss_func))\n",
        "evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)"
      ],
      "id": "2cc6d605",
      "execution_count": 40,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 1,
        "id": "f3e766cc"
      },
      "source": [
        "@trainer.on(Events.ITERATION_COMPLETED(every=print_every))\n",
        "def log_batch(trainer):\n",
        "    \"\"\"Print the latest training loss; registered to fire every `print_every` iterations.\"\"\"\n",
        "    state = trainer.state\n",
        "    # Map the running iteration counter to a 1-based batch index inside the current epoch.\n",
        "    batch = (state.iteration - 1) % state.epoch_length + 1\n",
        "    progress = f\"Epoch {state.epoch} / {num_epochs}, batch {batch} / {state.epoch_length}: \"\n",
        "    print(progress + f\"loss: {state.output:.3f}\")"
      ],
      "id": "f3e766cc",
      "execution_count": 41,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "lines_to_next_cell": 1,
        "id": "5e481c7e"
      },
      "source": [
        "@trainer.on(Events.EPOCH_COMPLETED)\n",
        "def log_epoch(trainer):\n",
        "    \"\"\"Evaluate the model on each available split after an epoch and print loss and accuracy.\"\"\"\n",
        "    print(f\"Epoch {trainer.state.epoch} / {num_epochs} average results: \")\n",
        "\n",
        "    def log_results(name, metrics):\n",
        "        # One line per split, with the split name padded to a fixed width.\n",
        "        print(\n",
        "            f\"{name + ':':6} loss: {metrics['loss']:.3f}, \"\n",
        "            f\"accuracy: {metrics['accuracy']:.3f}\"\n",
        "        )\n",
        "\n",
        "    # Train data (always present).\n",
        "    evaluator.run(train_loader)\n",
        "    log_results(\"train\", evaluator.state.metrics)\n",
        "\n",
        "    # Val and test data are optional. Compare against None explicitly:\n",
        "    # truth-testing a DataLoader calls len() on it instead.\n",
        "    for name, loader in ((\"val\", val_loader), (\"test\", test_loader)):\n",
        "        if loader is not None:\n",
        "            evaluator.run(loader)\n",
        "            log_results(name, evaluator.state.metrics)\n",
        "\n",
        "    print()\n",
        "    print(\"-\" * 80)\n",
        "    print()"
      ],
      "id": "5e481c7e",
      "execution_count": 42,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "68e29e44"
      },
      "source": [
        "# Start training.\n",
        "# Iterates over train_loader for num_epochs epochs.\n",
        "trainer.run(train_loader, max_epochs=num_epochs)"
      ],
      "id": "68e29e44",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VMsGSH973-A0"
      },
      "source": [
        "# Export the network to model.onnx using one random 1x1x128x128 example input.\n",
        "model.eval()\n",
        "image = torch.randn(1, 1, 128, 128).to(device)\n",
        "torch.onnx.export(\n",
        "    model,\n",
        "    image,\n",
        "    \"model.onnx\",\n",
        "    input_names=[\"data\"],\n",
        "    output_names=[\"output\"],\n",
        ")"
      ],
      "id": "VMsGSH973-A0",
      "execution_count": 45,
      "outputs": []
    }
  ]
}